{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `LibreASR`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 3\n",
    "# %matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lib.imports import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `Config`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conf, lang, builder_train, builder_valid, db, m, learn = parse_and_apply_config()\n",
    "(conf, lang, db, m, learn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "builder_train.print()\n",
    "builder_valid.print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "builder_train.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tpl = db.one_batch()\n",
    "X, Ym, _, _ = tpl[0]\n",
    "Y, Y_lens, X_lens = tpl[1]\n",
    "what(X), what(X_lens), what(Y), what(Y_lens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_lens, Y_lens"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `Augmentation`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# db.aug(n=3) #0: signal, 7: spectro"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `Model`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# switch to gpu mode\n",
    "# m.convert_to_gpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = m.to(conf[\"cuda\"][\"device\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(m)\n",
    "print(\"n_params:\", n_params(m))"
   ]
  },
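  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Timing\n",
    "\n",
    "Time how long it takes to fetch one batch and to run a single forward pass without gradients."
   ]
  },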
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "_in = db.one_batch()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "with torch.no_grad():\n",
    "    _out = m(cudaize(_in))\n",
    "    print(\"X\", _in[0].shape, \"->\", _out.shape)\n",
    "    print(\"Y\", _in[1].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `Train`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt = \"model\"\n",
    "# ckpt = \"libreasr-fr-13-12-2020\"\n",
    "# ckpt = \"chkp-en-57-wer\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Try to resume from `ckpt`; a failed load is only reported, so the notebook\n",
    "# carries on with whatever weights `learn` currently holds.\n",
    "try:\n",
    "    learn.load(ckpt, with_opt=True, strict=False)\n",
    "except Exception as load_err:\n",
    "    print(\"failed loading ckpt\", ckpt)\n",
    "    print(load_err)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `fit`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.unfreeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# learn.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypers\n",
    "do_warmup = True\n",
    "epochs = 20\n",
    "warmup_epochs = 2.9\n",
    "warmup = int(warmup_epochs + 1)\n",
    "\n",
    "lr = 5e-4\n",
    "wd = 0.01 # 0.1\n",
    "\n",
    "div = 500. # 100.\n",
    "div_final = 1.001\n",
    "\n",
    "pct_start = warmup_epochs / warmup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# warmup\n",
    "if do_warmup:\n",
    "    try:\n",
    "        learn.fit_one_cycle(warmup, lr, wd=wd, cbs=learn.extra_cbs, div_final=div_final, div=div, pct_start=pct_start)\n",
    "    except (KeyboardInterrupt, Exception) as warmup_err:\n",
    "        # Best effort on purpose: a manual interrupt or a failure still falls\n",
    "        # through to the save below, but the reason is no longer hidden.\n",
    "        print(\"warmup stopped early:\", repr(warmup_err))\n",
    "    m = m.to(conf[\"cuda\"][\"device\"])\n",
    "    learn.save(ckpt, with_opt=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tune\n",
    "m = m.to(conf[\"cuda\"][\"device\"])\n",
    "learn.load(ckpt, with_opt=True, strict=False, device=conf[\"cuda\"][\"device\"])\n",
    "learn.fit(epochs, lr, wd=wd, cbs=learn.extra_cbs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raise Exception(\"done\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `checks`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.recorder.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save model for later\n",
    "# learn.save(checkpoints[-1], with_opt=True)\n",
    "learn.save(ckpt, with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save(\"libreasr-fr-13-12-2020\", with_opt=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `Decode`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sampls = 999999\n",
    "test_name = \"dev-all\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = list(learn.test(min_samples=sampls, train=False, save_best=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each label length with the WER of the *same* result. Filtering the two\n",
    "# lists independently can leave them misaligned (or of different lengths, which\n",
    "# breaks the np.stack in the next cell) whenever a result lacks one of the fields.\n",
    "scored = [\n",
    "    (len(res[\"text/ground_truth\"]), res[\"metrics/wer\"])\n",
    "    for res in _\n",
    "    if res.get(\"text/ground_truth\") is not None\n",
    "    and len(res[\"text/ground_truth\"]) != 0\n",
    "    and res.get(\"metrics/wer\") is not None\n",
    "]\n",
    "ylens = np.array([ylen for ylen, wer in scored])\n",
    "wers = np.array([wer for ylen, wer in scored])\n",
    "ylens[:3], wers[:3], len(ylens), len(wers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stack = np.stack([ylens, wers]).T\n",
    "stack.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(stack, columns=[\"ylen\", \"wer\"])\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "grouped = df.groupby(\"ylen\").mean()[\"wer\"].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save plots\n",
    "# Plot against the real label lengths: `grouped` only holds the mean values, so\n",
    "# plotting it alone would put positions 0..k-1 on the x axis instead of lengths.\n",
    "wer_by_len = df.groupby(\"ylen\")[\"wer\"].mean()\n",
    "plt.plot(wer_by_len.index, wer_by_len.values)\n",
    "plt.ylim(0.0, 1.2)\n",
    "plt.ylabel(\"WER\")\n",
    "plt.xlabel(\"Label Length\")\n",
    "plt.savefig(f\"{test_name}.png\", dpi=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save tests\n",
    "import pickle\n",
    "# Use a context manager so the file is flushed and closed even if dumping fails.\n",
    "with open(f'{test_name}.pkl', 'wb') as pkl_file:\n",
    "    pickle.dump(_, pkl_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _sort(results, rev):\n",
    "    results = filter(lambda x: \"text/ground_truth\" in list(x.keys()), results)\n",
    "    results = sorted(results, key=lambda x: x[\"metrics/wer\"], reverse=rev)\n",
    "    return results\n",
    "\n",
    "def best(results):\n",
    "    return _sort(results, rev=False)\n",
    "\n",
    "def worst(results, n=25):\n",
    "    results = _sort(results, rev=True)[:n]\n",
    "    df = builder_valid.df\n",
    "    for res in results:\n",
    "        def filter_fn(row):\n",
    "            s1 = (res[\"text/ground_truth\"])\n",
    "            s2 = sanitize_str(row.label)\n",
    "            in_ok = s1 in s2\n",
    "            l_ok = abs(len(s1) - len(s2)) < 10\n",
    "            return in_ok and l_ok\n",
    "        rows = df[df.apply(filter_fn, axis=1)]\n",
    "        row = rows\n",
    "        print(len(row))\n",
    "        if len(row) >= 1:\n",
    "            row = row.iloc[0]\n",
    "            file = row.file\n",
    "            sr = int(row.sr)\n",
    "            xstart = int(row.xstart / 1000. * sr)\n",
    "            xlen = int(row.xlen / 1000. * sr)\n",
    "            data, sr = torchaudio.load(file, frame_offset=xstart, num_frames=xlen)\n",
    "            print(\"PRED:\", res[\"text/prediction\"])\n",
    "            print(\"TRUE:\", res[\"text/ground_truth\"])\n",
    "            print(\"WER :\", res[\"metrics/wer\"])\n",
    "            display(Audio(data.numpy(), rate=sr))\n",
    "        else: print(\"err\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "worst(_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m.joint.joint[2].bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m.joint.joint[2].bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save(\"libreasr-en-10-12-2020\", with_opt=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Archive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lib.model_utils import *\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_asr_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
